################################
## Script: 05_prep_gpt4o.R
## Purpose: This code prepares the cross encoder scores
## for gpt4o. It also tunes the cross encoder cutoff. 
## Data In:
## 1) recall master set
## data/recall_master_updated_11_24_2024_public.rds
## 2) recall held out set (only used to merge in data, not used in tuning)
## data/recall_holdout_combined_public.rds
## 3) Sbert cross encoder files
## data/cross_encoder/cross_encoder_bioweapons_cosine_sts_{num}.csv
## 4) bioweapons files with sbert embeddings
## data/bioweapons_casestudy_5_20_2024_sbert_embeddings.json
## 5) claim, subject annotations
## data/gpt4o_annotations/annotate_lists_{num}.rds
## Data Out:
## 1) precision coding files for sbert:
## a) round 1 (lower cutoff):
## data/precision_sbert_master.rds
## data/precision_sbert_coding_ra_1.csv
## data/precision_sbert_coding_ra_2.csv
## b) round 2 
##  data/precision_sbert_master_recall_cutoff.rds
##  data/precision_sbert_coding_recall_cutoff_ra_2.csv
## data/precision_sbert_coding_recall_cutoff_ra_1.csv
## 2) recall files with sbert cross score included (for calculating recall in later file):
## data/recall_master_updated_11_24_2024_with_sbert_cross_score_public.rds
## data/recall_holdout_combined_with_sbert_cross_score_public.rds
## 3) File to pass through llm annotators:
## data/gpt4o_toannotate.rds
## Notes: 
## We comment out code to write out files
## so that replicators can run script without
## having to worry about saving over files 
## but still understand which code produced which output. 


library(tidyverse)
library(jsonlite)

###########################
## Dependencies ###########
###########################

## Hand-coded true positive pairs (the recall set used for tuning).
recall <- readRDS("data/recall_master_updated_11_24_2024_public.rds")

## Build an order-independent key for every pair: the two article
## IDs in a row are sorted and joined with "_", so a pair receives
## the same match_id whichever article is the ego and whichever is
## the alter. This lets us merge on pairs regardless of direction.
## NOTE(review): apply() converts the whole data frame to a matrix
## before each row reaches the function; if the ID columns are not
## already character this may affect how the IDs are rendered in the
## key -- confirm against the stored IDs.
mat <- which(colnames(recall) %in% c("article_id", "focal_article"))
recall$match_id <- apply(recall, 1, function(pair_row) {
  ids <- pair_row[mat]
  paste0(ids[order(ids)], collapse = "_")
})

## Held-out recall pairs. The cross-encoder cutoff is never tuned on
## these; they are loaded only so cross-score information can be
## merged in for out-of-sample statistics downstream.
recall_holdout <- readRDS("data/recall_holdout_combined_public.rds")


## sbert cross encoder scores
## The scores are stored across several CSV files; keep every file in
## the directory whose name contains "cross_encoder_bioweapons_cosine".
files <- list.files("data/cross_encoder/")
files <- files[grepl("cross_encoder_bioweapons_cosine", files)]

## Fail loudly when nothing matches. Without this check paste0()
## would turn the empty vector into the bare directory path and the
## read below would fail with an unrelated error.
if (length(files) == 0) {
  stop("No cross encoder score files found in data/cross_encoder/",
       call. = FALSE)
}

files <- paste0("data/cross_encoder/",
                files)

## Read each file and stack the rows into one data frame of pairs.
out <- lapply(files, read.csv)
out <- bind_rows(out)

## Order-independent pair key for merging: the ego and alter IDs of
## each row are sorted and joined with "_".
## NOTE(review): apply() converts the whole data frame to a matrix
## before each row reaches the function; if the ID columns were read
## in as numbers this may affect how they are rendered in the key --
## confirm against the stored IDs.
mat <- which(colnames(out) %in% c("ego_id", "alter_id"))
out$match_id <- apply(out, 1, function(pair_row) {
  ids <- pair_row[mat]
  paste0(ids[order(ids)], collapse = "_")
})

## Full set of articles, streamed in from the JSON file.
total_art <- jsonlite::stream_in(
  file("data/bioweapons_casestudy_5_20_2024_sbert_embeddings.json")
)

## Annotation output files (claim and subject lists).
## Keep the files whose name contains "list", and recover the numeric
## chunk index by stripping the "annotate_lists_" prefix and the
## ".rds" extension from each file name.
list_annotations <- grep("list",
                         list.files("data/gpt4o_annotations/"),
                         value = TRUE)
list_annotations_num <- as.numeric(
  gsub("annotate_lists_|\\.rds", "", list_annotations)
)

## Turn the bare file names into paths relative to the project root.
list_annotations <- paste0("data/gpt4o_annotations/",
                           list_annotations)

## Read the annotation lists back in.
## Each saved file holds a list of two lists: the first contains the
## output for the subject lists, the second the output for the claim
## lists. The articles are split into 100 chunks in the same way as
## for the original annotation run, so chunk i lines up with
## annotation file i.

total_art <- split(total_art, 1:100)

## Pull the message text out of every response in one list.
extract_content <- function(responses) {
  unlist(lapply(responses, function(resp) {
    resp$choices$message.content
  }))
}

for (chunk in seq_len(100)) {
  ## locate and load the annotation file numbered `chunk`
  annotation <- readRDS(list_annotations[which(list_annotations_num == chunk)])

  ## element 1 = subject lists, element 2 = claim lists
  total_art[[chunk]]$subject_list <- extract_content(annotation[[1]])
  total_art[[chunk]]$claim_list <- extract_content(annotation[[2]])
}

## Stack the chunks back into a single data frame.
total_art <- bind_rows(total_art)


#######################################
#### Creating Precision Coding File ###
#######################################

## in this section we create precision coding
## files so we can test the precision of SBERT
## as a standalone classifier

## whether in 5 day date window: TRUE when the ego date is no more
## than 5 days before or after the alter date (inclusive).
## Subtracting two Dates gives the gap in days, so this uses base R
## only and does not require lubridate::days() to be attached
## (library(tidyverse) attaches lubridate only from tidyverse 2.0.0).
out$date_window <- abs(as.numeric(
  as.Date(out$ego_date) - as.Date(out$alter_date)
)) <= 5

## V1: by cross score quartile
## quantile() with its defaults returns the 0/25/50/75/100% points;
## elements 2-4 (25%, 50%, 75%) are the boundaries between buckets.
quartiles <- quantile(out$cross_score)

## findInterval() counts how many boundaries a score has reached
## (0 to 3); adding 1 turns that count into the bucket label index.
## A score equal to a boundary falls into the higher bucket.
quartile_labels <- c("1st Quartile",
                     "2nd Quartile",
                     "3rd Quartile",
                     "4th Quartile")
out$cross_score_quartile <- quartile_labels[
  findInterval(out$cross_score, quartiles[2:4]) + 1L
]
table(out$cross_score_quartile)
## approx. ~98080 in each bucket

## Draw 30 pairs at random, without replacement, from each
## cross-score quartile. The seed is set once and the buckets are
## sampled in list order so the draw is reproducible.
precision_sbert <- split(out, out$cross_score_quartile)
set.seed(97405)
precision_sbert <- lapply(precision_sbert, function(bucket) {
  keep <- sample(seq_len(nrow(bucket)),
                 30, replace = FALSE)
  bucket[keep, ]
})
precision_sbert <- bind_rows(precision_sbert)

## Shuffle the pooled sample so rows are not ordered by quartile.
shuffled <- sample(seq_len(nrow(precision_sbert)),
                   nrow(precision_sbert),
                   replace = FALSE)
precision_sbert <- precision_sbert[shuffled, ]

## Assign the first 60 shuffled rows to RA1 and the last 60 to RA2
## (4 quartiles x 30 pairs = 120 rows).
## NOTE(review): the earlier comment here said each file should be
## coded by three RAs, but only two coders are assigned -- confirm.
precision_sbert$coder <- rep(c("RA1", "RA2"), each = 60)

#saveRDS(precision_sbert, "data/precision_sbert_master.rds")

## creating coding files: empty columns for the coders to fill in
precision_sbert$same_subject <- NA
precision_sbert$same_claim <- NA

#write.csv(precision_sbert[precision_sbert$coder == "RA1",
#                          c("ego_id",
#                            "alter_id",
#                            "ego_summary",
#                            "alter_summary",
#                            "same_subject",
#                            "same_claim")],
#          "data/precision_sbert_coding_ra_1.csv")


#write.csv(precision_sbert[precision_sbert$coder == "RA2",
#                          c("ego_id",
#                            "alter_id",
#                            "ego_summary",
#                            "alter_summary",
#                            "same_subject",
#                            "same_claim")],
#          "data/precision_sbert_coding_ra_2.csv")


## v2: 2nd precision pass
## We investigate a smaller range of cutoffs,
## all above the minimum threshold
## identified in the recall set (.48)

## Flag which candidate pairs appear in the recall set, then take
## the distribution of cross scores among those pairs only.
out$recall <- ifelse(out$match_id %in% recall$match_id,
                     "Recall Set", "Not\nRecall Set")
in_recall_set <- out$recall == "Recall Set"
quartiles_recall <- quantile(out$cross_score[in_recall_set])

## Apply those cutoffs to the full set of pairs.
## Elements 1-4 of quartiles_recall are the recall-set minimum and
## its 25/50/75% points. findInterval() counts how many of them a
## score has reached (0 to 4); adding 1 gives the label index, so
## anything under the recall-set minimum is "Below Recall Set" and
## anything at or above the 75% point is "4th Quartile Recall".
recall_bucket_labels <- c("Below Recall Set",
                          "1st Quartile Recall",
                          "2nd Quartile Recall",
                          "3rd Quartile Recall",
                          "4th Quartile Recall")
out$cross_score_quartile_recall <- recall_bucket_labels[
  findInterval(out$cross_score, quartiles_recall[1:4]) + 1L
]

## Number of candidate pairs in each sbert recall bucket
## (the denominators needed for the overall precision calculation).
counts_cross_score <- summarize(
  group_by(out, cross_score_quartile_recall),
  n = n()
)

## Look up each recall pair's bucket by match_id and store it on the
## recall set. Pairs with no match in `out` get NA.
recall_rows <- match(recall$match_id, out$match_id)
recall$out_cross_score_quartile_recall <-
  out$cross_score_quartile_recall[recall_rows]

## Same lookup for the held-out set. The cross threshold cutoff below
## is not tuned on these data; the bucket is only merged in here.
## NOTE(review): assumes recall_holdout already carries a match_id
## column built the same way as above -- it is not created in this
## script; confirm.
holdout_rows <- match(recall_holdout$match_id, out$match_id)
recall_holdout$out_cross_score_quartile_recall <-
  out$cross_score_quartile_recall[holdout_rows]
 
## Draw 30 pairs at random, without replacement, from each recall
## bucket. The seed is set once and the buckets are sampled in list
## order (including "Below Recall Set", which is dropped afterwards)
## so the draw is reproducible.
precision_sbert <- split(out, out$cross_score_quartile_recall)
set.seed(97405)
precision_sbert <- lapply(precision_sbert, function(bucket) {
  keep <- sample(seq_len(nrow(bucket)),
                 30, replace = FALSE)
  bucket[keep, ]
})
precision_sbert <- bind_rows(precision_sbert)

## Shuffle the pooled sample so rows are not ordered by bucket.
shuffled <- sample(seq_len(nrow(precision_sbert)),
                   nrow(precision_sbert),
                   replace = FALSE)
precision_sbert <- precision_sbert[shuffled, ]

## Only pairs at or above the recall-set minimum are coded.
precision_sbert <- filter(precision_sbert,
                          cross_score_quartile_recall != "Below Recall Set")

## Assign the first 60 remaining rows to RA1 and the last 60 to RA2
## (4 recall quartiles x 30 pairs = 120 rows).
## NOTE(review): the earlier comment here said each file should be
## coded by three RAs, but only two coders are assigned -- confirm.
precision_sbert$coder <- rep(c("RA1", "RA2"), each = 60)

## For the master file, attach the number of candidate pairs in each
## row's bucket; these are the denominators used to calculate
## precision.
bucket_rows <- match(precision_sbert$cross_score_quartile_recall,
                     counts_cross_score$cross_score_quartile_recall)
precision_sbert$cross_score_quartile_recall_denominator_counts <-
  counts_cross_score$n[bucket_rows]

#saveRDS(precision_sbert, "data/precision_sbert_master_recall_cutoff.rds")

## creating coding files: empty columns for the coders to fill in
precision_sbert$same_subject <- NA
precision_sbert$same_claim <- NA

#write.csv(precision_sbert[precision_sbert$coder == "RA1",
#                          c("ego_id",
#                            "alter_id",
#                            "ego_summary",
#                            "alter_summary",
#                            "same_subject",
#                            "same_claim")],
#          "data/precision_sbert_coding_recall_cutoff_ra_1.csv")


#write.csv(precision_sbert[precision_sbert$coder == "RA2",
#                          c("ego_id",
#                            "alter_id",
#                            "ego_summary",
#                            "alter_summary",
#                            "same_subject",
#                            "same_claim")],
#          "data/precision_sbert_coding_recall_cutoff_ra_2.csv")


## save recall files with sbert cutoff information
#saveRDS(recall,
#        "data/recall_master_updated_11_24_2024_with_sbert_cross_score_public.rds")

#saveRDS(recall_holdout,
#        "data/recall_holdout_combined_with_sbert_cross_score_public.rds")

#################################
### Tune Cross Encoder Cutoff ###
#################################

## We tune the cross encoder cutoff purely based on
## recall (non-held out) data

## We set a threshold of .5 for the cross encoder.
## The comparison is >= so that these diagnostics describe exactly
## the pairs kept by the `cross_score >= .5` filter that builds the
## annotation file further down.
above_cutoff <- out$cross_score >= .5

## Recall of the recall training set at that threshold: how many
## recall pairs are among the candidate pairs that pass the cutoff,
## and what share of the recall set that is.
n_recall_kept <- sum(recall$match_id %in% out$match_id[above_cutoff])
n_recall_kept ## 116 in the original run
n_recall_kept / nrow(recall) ## 116 / 121 in the original run

## What share of candidate pairs remains
## after we limit to that threshold?
sum(above_cutoff) / nrow(out)

## We visualize the distribution
## of cross encoder scores by whether
## the pair is in the recall set,
## and visualize our .5 cross score cutoff
## This is included in the SI, Figure A7
ggplot(out,
       mapping = aes(x = cross_score,
                     fill = recall)) +
  geom_density(alpha = .5) +
  theme_bw() +
  labs(x = "Sbert Cross Encoder Score",
       y = "Density",
       fill = NULL,
       title = "Candidate Pairs Step 2: Cross Encoder Cutoff") +
  ## xintercept is passed as a constant rather than through aes():
  ## a mapped constant is evaluated against the plot data and draws
  ## one overlapping line per row of `out`.
  geom_vline(xintercept = .5,
             lty = 2) +
  annotate(geom = "label",
           label = "Setting cross encoder\ncutoff at .5 returns 95.8% of\n positive cases,\ndiscards remaining\n44.4% of candidate pairs",
           x = .25,
           y = 5)
## figures/cross_encoder_recall, 5x7


#######################################
### Creating File for GPT4o Annotation#
#######################################

## Here we create the file we pass to our GPT4o annotator.
## It holds every pair with a cross encoder score of at least .5
## whose two articles fall inside the date window.

## Row of total_art holding each pair's ego / alter article.
ego_rows <- match(out$ego_id, total_art$article_id)
alter_rows <- match(out$alter_id, total_art$article_id)

## Attach the claim and subject lists of both articles to the pair.
out$ego_claim <- total_art$claim_list[ego_rows]
out$alter_claim <- total_art$claim_list[alter_rows]

out$ego_subject <- total_art$subject_list[ego_rows]
out$alter_subject <- total_art$subject_list[alter_rows]

## Keep pairs that pass both the score cutoff and the date window.
out <- filter(out,
              cross_score >= .5,
              date_window)
nrow(out) ## 64,677 to annotate

## Drop the two bucket columns, which were only needed for the
## precision samples.
out <- dplyr::select(out,
                     -c(cross_score_quartile,
                        cross_score_quartile_recall))

#saveRDS(out, "data/gpt4o_toannotate.rds")
